import os
from re import X
from socket import AF_X25
import numpy as np
import matplotlib.pyplot as plt
import seaborn
import pandas as pd
import time   

# Global seaborn theme shared by every figure this script produces.
seaborn.set(style = 'darkgrid')
# seaborn.set_style("whitegrid")

# Initial typography.  NOTE: title_size and lable_size are assigned again a few
# lines further down this module, so those later values are the ones the
# plotting functions actually read.
title_size, lable_size, titel_pad = 20, 18, 10
# Opacity of the shaded bands drawn with fill_between.
alpha = 0.1
# Module-level smoothing window; the plotting functions reviewed here each
# define their own local window_length.
window_length = 5

# Colours indexed by method position; the plots use colors[-1] for the expert.
colors = [
    'red',
    'purple',
    'darkcyan',
    'orchid',
    'darkorange',
    'green',
    'royalblue',
]

# DMC task ids known to this script.
envs = [
    'dmc_acrobot_swingup',
    'dmc_cartpole_swingup',
    'dmc_cheetah_run',
    'dmc_finger_spin',
    'dmc_hopper_hop',
    'dmc_hopper_stand',
    'dmc_quadruped_run',
    'dmc_walker_stand',
    'dmc_walker_run',
    'dmc_walker_walk',
]

# Expert reference per task as [centre line, half-width of the shaded band]
# (presumably mean and std of expert episode returns -- confirm).
expert_performaces = {
    'dmc_acrobot_swingup':  [420.37427946423423, 106.6341843447902],
    'dmc_cartpole_swingup': [858.5333001741556, 0.4942731563676342],
    'dmc_cheetah_run':      [890.2325014116416, 19.321439649226903],
    'dmc_finger_spin':      [976.4, 9.046546302318912],
    'dmc_hopper_hop':       [318.66631742075526, 7.362219804946876],
    'dmc_hopper_stand':     [939.4730723487439, 9.270191609032274],
    'dmc_quadruped_run':    [547.3013652984198, 136.38115598687412],
    'dmc_walker_stand':     [970.0151975860175, 20.19982066092738],
    'dmc_walker_run':       [778.1798692484144, 10.23481206581241],
    'dmc_walker_walk':      [961.3942600808514, 17.851356872831055],
}

# Per-task frame cap, passed as frame_limit to read_multi_files.
frames_dict = {
    'dmc_finger_spin':      1600000,
    'dmc_cartpole_swingup': 1600000,
    'dmc_hopper_stand':     1800000,
    'dmc_walker_walk':      1800000,
    'dmc_walker_stand':     1600000,
    'dmc_quadruped_run':    1600000,
    'dmc_cheetah_run':      1600000,
}

def env2name(env):
    """Turn an environment id such as ``'dmc_walker_walk'`` into a plot title.

    The suite prefix (everything up to the first underscore) is dropped and
    every remaining underscore-separated word is title-cased, e.g.
    ``'dmc_walker_walk'`` -> ``'Walker Walk'`` and
    ``'dmc_ball_in_cup_catch'`` -> ``'Ball In Cup Catch'``.  Task ids with
    more than two words are no longer truncated to their first two words.
    """
    _suite, _, task = env.partition('_')
    return ' '.join(word.title() for word in task.split('_'))

def r_means(x):
    """Return the NaN-ignoring mean of each row of the 2-D array ``x`` (axis 1)."""
    return np.nanmean(x, axis=1)


def r_stderrs(x):
    """Return the standard error of each row of ``x``, ignoring NaNs.

    The sample count is the number of non-NaN entries per row.  Using
    ``np.count_nonzero(x)`` would count NaN as a sample (NaN is non-zero) and
    drop genuine zero-valued returns.
    """
    n_valid = np.count_nonzero(~np.isnan(x), axis=1)
    return np.nanstd(x, axis=1) / np.sqrt(n_valid)


def r_mins(x):
    """Lower band edge per row: mean minus standard error."""
    return r_means(x) - r_stderrs(x)  # alternative once used: np.nanmin(x, axis=1)


def r_maxs(x):
    """Upper band edge per row: mean plus standard error."""
    return r_means(x) + r_stderrs(x)  # alternative once used: np.nanmax(x, axis=1)

# Typography read by the plotting functions (these replace the title_size /
# lable_size values assigned earlier in the module).
title_size, lable_size = 25, 24
ticksize = 18
line_width = 2.5

plt.rcParams.update({
    # Tick-label font size on both axes.
    'xtick.labelsize': ticksize,
    'ytick.labelsize': ticksize,
    # Font type 42 (TrueType) for PDF and PostScript output.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

def read_multi_files(folder_path, require_suffix=None, suffix = '', column='episode_reward', x_column='frame', frame_limit=np.inf, readfrom='eval'): # Test Returns Mean
    """Collect one metric column from every run folder under ``folder_path + suffix``.

    Each entry of that directory is treated as a run folder containing
    ``<readfrom>.csv``; the last character of the folder name is parsed as the
    seed.  A few hand-picked runs (finger_spin/cartpole_swingup seed 3 and
    cartpole_swingup seed 0 of "patch" runs) are skipped.  Runs whose largest
    ``x_column`` value is below 1600000 are logged to ``./analysis.txt``
    (except when ``folder_path`` contains "bc").

    All runs are truncated to a common length: the smallest, over runs, number
    of rows whose ``x_column`` is ``<= frame_limit``.

    Args:
        folder_path: path prefix of the directory holding the run folders.
        require_suffix: unused; kept for call compatibility.
        suffix: string appended to ``folder_path`` before listing.
        column: CSV column to collect.
        x_column: CSV column used as the x axis.
        frame_limit: rows with ``x_column`` above this do not count toward the
            common length.
        readfrom: file stem of the CSV inside each run folder.

    Returns:
        ``(reward_data, x)``: a list with one truncated pandas Series per run,
        and the numpy x values of the run that determined the common length.

    Raises:
        ValueError: if no run folder was loaded.
    """
    folder = folder_path + suffix
    # Sorted so that which run supplies the x axis does not depend on the
    # filesystem's listing order.
    files = sorted(os.listdir(folder))
    min_len = np.inf
    x = None
    tmp = []
    # Per-seed tally (the missing-seed report that used it is disabled).
    seed_dict = {seed: 0 for seed in range(0, 7)}

    with open("./analysis.txt", "a") as ana_file:
        for file in files:
            # Hand-curated exclusions of specific runs.
            if (("finger_spin" in file) or ("cartpole_swingup" in file)) and ("seed_3" in file) and ("patch" in file):
                seed_dict[3] += 1
                continue
            if ("cartpole_swingup" in file) and ("seed_0" in file) and ("patch" in file) and ("bc" not in file):
                seed_dict[0] += 1
                continue
            file_path = folder + '/' + file + '/{}.csv'.format(readfrom)
            df = pd.read_csv(file_path)

            update_time = time.ctime(os.stat(file_path).st_mtime)
            if df[x_column].max() < 1600000 and "bc" not in folder_path:
                ana_file.write(file + ', max step ' + str(df[x_column].max()) + ', last update' + update_time + '\n')
            file_seed = int(file[-1])
            seed_dict[file_seed] = seed_dict.get(file_seed, 0) + 1

            # Compare the in-limit row count, not the raw length: a longer file
            # can still have fewer rows within frame_limit, and must then
            # shrink the common length.
            valid_len = int((df[x_column] <= frame_limit).sum())
            if valid_len < min_len:
                min_len = valid_len
                x = df[x_column][:min_len]

            if column == "similarity":
                # Undo the run-specific scale encoded in the run name.
                if "auto1.3" in file:
                    df[column] /= 1.3
                elif "auto0.5" in file:
                    df[column] /= 0.5

            tmp.append(df[column])

    if x is None:
        raise ValueError("no runs were loaded from {}".format(folder))

    reward_data = [data[:min_len] for data in tmp]
    return reward_data, x.values

def MA(value, step):
    """Smooth ``value`` with a forward-looking moving average.

    Point ``i`` is the mean of ``value[i:i + w]`` (truncated at the end of the
    sequence).  For ``step > 1`` the window ``w`` is ``int(step / 1.5)`` for
    the first 5 points, ``int(step / 1.3)`` for the next 5,
    ``int(step / 1.1)`` for the next 5 and ``int(step)`` afterwards, clamped
    to at least 1 so each point always averages at least itself.  For
    ``step <= 1`` each point is just ``value[i]``.

    Args:
        value: 1-D sequence supporting ``len``, integer indexing and slicing
            (list or numpy array).
        step: nominal window length.

    Returns:
        list of ``len(value)`` smoothed values.
    """
    def _width(i):
        # Shorter windows over the first 15 points.
        if i < 5:
            return int(step / 1.5)
        if i < 10:
            return int(step / 1.3)
        if i < 15:
            return int(step / 1.1)
        return int(step)

    ma_value = []
    for i in range(len(value)):
        if step > 1:
            window = value[i:i + max(1, _width(i))]
        else:
            window = [value[i]]
        # window is never empty: it always contains value[i].
        ma_value.append(sum(window) / len(window))
    return ma_value


def _plot_motivation_panel(axis, env, names, methods, traj_num, window_length=5, require_suffix=None):
    """Draw one motivation panel: smoothed mean +/- stderr per method plus the expert band.

    Runs for ``names[i]`` are read from ``exp_local/<env>/<names[i]>_<traj_num>``
    and drawn in ``colors[i]`` with label ``methods[i]``.
    """
    x = None
    for i, (method, run_name) in enumerate(zip(methods, names)):
        name = run_name + "_{}".format(traj_num)
        returns, x = read_multi_files('exp_local/' + env + '/' + name, require_suffix)
        returns = np.array(returns).T  # one row per recorded frame, one column per run

        y1 = r_mins(returns)
        y2 = r_maxs(returns)
        y1[np.isnan(y1)] = 0.0
        y2[np.isnan(y2)] = 0.0
        # MA returns lists; comparing two lists yields a single bool rather
        # than the per-point mask fill_between's `where` expects.
        ma_y1 = np.asarray(MA(y1, window_length))
        ma_y2 = np.asarray(MA(y2, window_length))
        axis.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
        axis.plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=method)

    # Expert reference, drawn on the x axis of the last method read.
    expert_performace = expert_performaces[env]
    y = np.ones_like(x) * expert_performace[0]
    ma_y2 = y + expert_performace[1]
    ma_y1 = y - expert_performace[1]
    axis.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[-1], interpolate=True, alpha=alpha)
    axis.plot(x, y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

    axis.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
    axis.set_xlabel('Frames', fontsize=lable_size)
    axis.set_ylabel('Episode Return', fontsize=lable_size)
    axis.ticklabel_format(style='sci', scilimits=(-1, 2), axis='x')


def draw_motivation(traj_num=10):
    """Compare shared- vs independent-encoder AIL on three DMC tasks (1x3 grid).

    Saves ``motivations_<traj_num>.png``.
    """
    fig, ax = plt.subplots(1, 3, figsize=(24, 5))
    plt.subplots_adjust(hspace=0.6, wspace=0.25)

    methods = ['Shared-Encoder', 'Independent-Encoder']
    # (task, run-name prefixes in the same order as `methods`).
    # Alternatives tried before: dmc_hopper_hop for panel 2,
    # dmc_walker_walk / dmc_walker_run for panel 3.
    panels = [
        ('dmc_finger_spin', ['encairl_ss_enclrlr3', 'noshare_encairl_ss']),
        ('dmc_cheetah_run', ['encairl_ss', 'noshare_encairl_ss']),
        ('dmc_walker_stand', ['encairl_ss', 'noshare_encairl_ss']),
    ]
    for axis, (env, names) in zip(ax, panels):
        _plot_motivation_panel(axis, env, names, methods, traj_num)

    ax[2].legend(loc='center right', ncol=1, frameon=False, fontsize="x-large") # bbox_to_anchor=(0.7, 1.3), 

    plt.savefig("motivations_{}.png".format(traj_num), bbox_inches='tight')


def draw_motivation_single(traj_num=10):
    """Plot shared- vs independent-encoder AIL on dmc_finger_spin in one panel.

    Besides each method's smoothed mean +/- stderr band, every individual run
    is overlaid as a faint line.  Saves ``motivation_<traj_num>.png`` and
    ``.pdf``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    # plt.subplots_adjust(hspace=0.6, wspace=0.25)

    methods = ['Shared-Encoder', 'Independent-Encoder']
    # names = ['encairl_ss_enclrlr3', 'noshare_encairl_ss']
    names = ['encairl_ss', 'noshare_encairl_ss']

    env = 'dmc_finger_spin'
    window_length = 5
    require_suffix = None
    expert_performace = 976.4, 9.046546302318912
    frame_limit = frames_dict[env]
    x = None
    for i, (method, run_name) in enumerate(zip(methods, names)):
        name = run_name + "_{}".format(traj_num)
        returns, x = read_multi_files('exp_local/' + env + '/' + name, require_suffix, frame_limit=frame_limit)
        returns = np.array(returns).T  # one row per recorded frame, one column per run

        y1 = r_mins(returns)
        y2 = r_maxs(returns)
        y1[np.isnan(y1)] = 0.0
        y2[np.isnan(y2)] = 0.0
        # MA returns lists; convert so `where` is an element-wise mask
        # instead of one lexicographic list comparison.
        ma_y1 = np.asarray(MA(y1, window_length))
        ma_y2 = np.asarray(MA(y2, window_length))
        ax.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
        ax.plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=method)

        # Faint line for each individual run.
        for j in range(returns.shape[1]):
            ax.plot(x, returns[:, j], color=colors[i], linewidth=line_width, alpha=0.2)

    y = np.ones_like(x) * expert_performace[0]
    ma_y2 = y + expert_performace[1]
    ma_y1 = y - expert_performace[1]
    ax.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[-1], interpolate=True, alpha=alpha)
    ax.plot(x, y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

    ax.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
    ax.set_xlabel('Frames', fontsize=lable_size)
    ax.set_ylabel('Episode Return', fontsize=lable_size)
    ax.ticklabel_format(style='sci', scilimits=(-1, 2), axis='x')

    ax.legend(loc='center', bbox_to_anchor=(0.45,-0.25), ncol=3, frameon=False, fontsize="x-large", labelspacing=0.1, columnspacing=0.5) # bbox_to_anchor=(0.7, 1.3), 

    plt.savefig("motivation_{}.png".format(traj_num), bbox_inches='tight')
    plt.savefig("motivation_{}.pdf".format(traj_num), bbox_inches='tight')

def draw_main(traj_num=10):
    """Plot smoothed learning curves of all main methods on seven DMC tasks.

    Draws a 2x4 grid (the unused eighth panel is hidden).  For every task and
    method, runs are loaded from ``exp_local/<env>/<name>_<traj_num>``; the
    curve is the moving average of the across-run mean and the band is
    mean +/- standard error.  BC is drawn as a flat line at each run's final
    evaluation score.  The expert reference is a dashed line with a shaded
    band.  Saves ``curve_dmc_<traj_num>_smooth_<window>.png`` and ``.pdf``.

    Args:
        traj_num: suffix selecting the run folders (presumably the number of
            expert demonstrations -- confirm); 5 also selects a different
            PatchAIL-Bonus configuration.
    """
    fig_row, fig_col = 2, 4
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 5*fig_row))
    plt.subplots_adjust(hspace=0.4, wspace=0.33)

    envs = ['dmc_cartpole_swingup', 'dmc_cheetah_run', 'dmc_finger_spin', 'dmc_hopper_stand', 'dmc_walker_stand', 'dmc_walker_walk', 'dmc_quadruped_run']

    methods = ['PatchAIL w.o. Reg. (Ours)', 'PatchAIL-Weight (Ours)', 'PatchAIL-Bonus (Ours)', 'Shared-Encoder AIL', 'Independent-Encoder AIL', 'BC']
    names = ['patchairl_ss', 'patchairl_ss_kldsimregauto1.3_buf100w_randomexpdemo_noaugsim', 'patchairl_ss_kldsimregaugauto0.5_buf100w_randomexpdemo_noaugsim', 'encairl_ss', 'noshare_encairl_ss', 'bc']

    if traj_num == 5:
        names = ['patchairl_ss', 'patchairl_ss_kldsimregauto1.3_buf100w_randomexpdemo_noaugsim', 'patchairl_ss_kldsimregaugauto0.5_buf100w_randomexpdemo_noaugsim_rewscale0.5', 'encairl_ss', 'noshare_encairl_ss', 'bc']

    # env -> {method: (mean curve, stderr curve)}; only consumed by the
    # commented-out LaTeX table code that follows this function body.
    tab_results = {}

    num_method = len(methods)
    window_length = 2

    for f_i in range(fig_row):
        for f_j in range(fig_col):
            env_idx = f_i * fig_col + f_j
            if env_idx >= len(envs):
                ax[f_i][f_j].set_visible(False)
                continue
            env = envs[env_idx]
            axis = ax[f_i, f_j]
            score_dict = {}
            require_suffix = None
            expert_performace = expert_performaces[env]
            max_x = 0
            x_axis = None
            frame_limit = frames_dict.get(env, np.inf)
            for i in range(num_method):
                name = names[i] + "_{}".format(traj_num)
                path_name = 'exp_local/' + env + '/' + name

                returns, x = read_multi_files(path_name, require_suffix, frame_limit=frame_limit)
                returns = np.array(returns).T  # one row per recorded frame, one column per run
                if max_x < len(x):
                    max_x = len(x)
                    x_axis = x

                if names[i] == "bc":
                    # Hold each run's final BC score constant along the longest
                    # x axis seen so far: (frames, runs) -> (max_x, runs).
                    # Rows are frames here, so take the last row; iterating
                    # rows would pick the last run instead.
                    returns = np.tile(returns[-1:, :], (max_x, 1))
                    x = x_axis

                y1 = r_mins(returns)
                y2 = r_maxs(returns)
                y1[np.isnan(y1)] = 0.0
                y2[np.isnan(y2)] = 0.0
                # MA returns lists; convert so `where` is an element-wise mask.
                ma_y1 = np.asarray(MA(y1, window_length))
                ma_y2 = np.asarray(MA(y2, window_length))

                score_dict[methods[i]] = r_means(returns), r_stderrs(returns)

                axis.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
                # Baselines dashed, PatchAIL variants solid.
                linestyle = 'solid' if "Patch" in methods[i] else 'dashed'
                axis.plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=methods[i], linestyle=linestyle)

            y = np.ones(max_x) * expert_performace[0]
            ma_y2 = y + expert_performace[1]
            ma_y1 = y - expert_performace[1]
            axis.fill_between(x_axis, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[-1], interpolate=True, alpha=alpha)
            axis.plot(x_axis, y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

            axis.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
            axis.set_xlabel('Frames', fontsize=lable_size)
            axis.set_ylabel('Episode Return', fontsize=lable_size)
            axis.ticklabel_format(style='sci', scilimits=(-1, 2), axis='x')

            tab_results[env] = score_dict

    # A single shared legend, placed to the right of the last visible panel.
    ax[fig_row-1, fig_col-2].legend(loc='center right', ncol=1, frameon=False, fontsize="x-large", bbox_to_anchor=(2.3, 0.5))
    plt.savefig("curve_dmc_{}_smooth_{}.png".format(traj_num, window_length), bbox_inches='tight')
    plt.savefig("curve_dmc_{}_smooth_{}.pdf".format(traj_num, window_length), bbox_inches='tight')
    
    # print("500K step scores & PatchAIL w.o. Reg. & PatchAIL-Weight & PatchAIL-Bonus & Shared-Encoder AIL & Independent-Encoder AIL & BC & Expert \\\\")
    # print("\\hline")

    # for env in tab_results: 
    #     # print("Env: {}".format(env))
    #     score_dict = tab_results[env]
    #     # print("100 K results:")
    #     # for key in score_dict:
    #     #     print(key, score_dict[key][0][5],"$\pm$",score_dict[key][1][5])
        
    #     print("{} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f}\\\\".format(
    #         env2name(env),
    #         score_dict["PatchAIL w.o. Reg."][0][25], score_dict["PatchAIL w.o. Reg."][1][25], 
    #         score_dict["PatchAIL-Weight"][0][25], score_dict["PatchAIL-Weight"][1][25], 
    #         score_dict["PatchAIL-Bonus"][0][25], score_dict["PatchAIL-Bonus"][1][25], 
    #         score_dict["Shared-Encoder AIL"][0][25], score_dict["Shared-Encoder AIL"][1][25], 
    #         score_dict["Independent-Encoder AIL"][0][25], score_dict["Independent-Encoder AIL"][1][25], 
    #         score_dict["BC"][0][25], score_dict["BC"][1][25],
    #         expert_performaces[env][0], expert_performaces[env][1],
    #     ))


    # print("1M step scores & PatchAIL w.o. Reg. & PatchAIL-Weight & PatchAIL-Bonus & Shared-Encoder AIL & Independent-Encoder AIL & BC & Expert \\\\")
    # print("\\hline")

    # for env in tab_results: 
    #     score_dict = tab_results[env]
        
    #     print("{} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f} & {:.0f}$\pm${:.0f}\\\\".format(
    #         env2name(env),
    #         score_dict["PatchAIL w.o. Reg."][0][50], score_dict["PatchAIL w.o. Reg."][1][50], 
    #         score_dict["PatchAIL-Weight"][0][50], score_dict["PatchAIL-Weight"][1][50], 
    #         score_dict["PatchAIL-Bonus"][0][50], score_dict["PatchAIL-Bonus"][1][50], 
    #         score_dict["Shared-Encoder AIL"][0][50], score_dict["Shared-Encoder AIL"][1][50], 
    #         score_dict["Independent-Encoder AIL"][0][50], score_dict["Independent-Encoder AIL"][1][50], 
    #         score_dict["BC"][0][50], score_dict["BC"][1][50],
    #         expert_performaces[env][0], expert_performaces[env][1],
    #     ))


def reliable_plot(traj_num):
    """Plot per-frame IQM curves with bootstrap confidence bands (rliable).

    For each of seven DMC tasks, every fifth recorded evaluation point of
    each method is aggregated across runs with the interquartile mean, and
    ``rly.get_interval_estimates`` supplies the interval (50000 bootstrap
    reps).  BC is a flat line at each run's final score (mean +/- stderr),
    and the expert reference is a dashed band.  Saves
    ``iqm_curve_dmc_<traj_num>_smooth_<window>.png``.
    """
    from rliable import library as rly
    from rliable import metrics

    fig_row, fig_col = 2, 4
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 5*fig_row))
    plt.subplots_adjust(hspace=0.6, wspace=0.3)

    envs = ['dmc_cartpole_swingup', 'dmc_cheetah_run', 'dmc_finger_spin', 'dmc_hopper_stand', 'dmc_walker_stand', 'dmc_walker_walk', 'dmc_quadruped_run']

    methods = ['PatchAIL w.o. Reg.', 'PatchAIL-Weight', 'PatchAIL-Bonus', 'Shared-Encoder AIL', 'Independent-Encoder AIL', ] #, 'BC']
    names = ['patchairl_ss', 'patchairl_ss_kldsimregauto1.3_buf100w_randomexpdemo_noaugsim', 'patchairl_ss_kldsimregaugauto0.5_buf100w_randomexpdemo_noaugsim', 'encairl_ss', 'noshare_encairl_ss', 'bc']
    num_method = len(methods)
    window_length = 1

    def iqm(scores):
        # IQM across runs, computed independently at each selected frame;
        # scores is (runs, frames).
        return np.array([metrics.aggregate_iqm(scores[..., frame])
                         for frame in range(scores.shape[-1])])

    for f_i in range(fig_row):
        for f_j in range(fig_col):
            if f_i*fig_col+f_j >= len(envs):
                ax[f_i][f_j].set_visible(False)
                continue
            env = envs[f_i*fig_col+f_j]
            axis = ax[f_i, f_j]

            require_suffix = None
            expert_performace = expert_performaces[env]
            max_x = 0
            x_axis = None
            scores_dict = {}
            # x positions per method.  Kept local instead of being written
            # into the module-level frames_dict (the env -> frame-limit map).
            method_frames = {}
            frame_limit = frames_dict.get(env, np.inf)
            for i in range(num_method):
                name = names[i]+"_{}".format(traj_num)
                path_name = 'exp_local/'+env+'/'+name
                if 'noshare' in name:
                    path_name = 'exp_local/wrong_noshare/'+env+'/'+name

                returns, x = read_multi_files(path_name, require_suffix, frame_limit=frame_limit)
                returns = np.array(returns)  # one row per run, one column per frame

                if max_x < len(x):
                    max_x = len(x)
                    x_axis = x

                # Subsample every fifth evaluation point.
                scores_dict[methods[i]] = returns[..., ::5]
                method_frames[methods[i]] = x[::5]

            iqm_scores, iqm_cis = rly.get_interval_estimates(
                scores_dict, iqm, reps=50000)

            for i, algorithm in enumerate(methods):
                metric_values = iqm_scores[algorithm]
                lower, upper = iqm_cis[algorithm]
                axis.plot(
                    method_frames[algorithm],
                    MA(metric_values, window_length),
                    color=colors[i],
                    linewidth=line_width,
                    marker="o",
                    label=algorithm)
                axis.fill_between(
                    method_frames[algorithm],
                    y1=MA(lower, window_length),
                    y2=MA(upper, window_length),
                    color=colors[i], alpha=0.2)

            # Plot BC: each run's final score held constant -> (runs, max_x).
            name = "bc_{}".format(traj_num)
            returns, x = read_multi_files('exp_local/'+env+'/'+name, require_suffix, frame_limit=frame_limit)
            returns = np.array(returns)
            returns = np.tile(returns[:, -1:], (1, max_x))

            y1 = r_mins(returns.T)
            y2 = r_maxs(returns.T)
            y1[np.isnan(y1)] = 0.0
            y2[np.isnan(y2)] = 0.0

            axis.fill_between(x_axis, y1, y2, where=y2 >= y1, facecolor=colors[num_method], interpolate=True, alpha=alpha)
            axis.plot(x_axis, r_means(returns.T), color=colors[num_method], linewidth=line_width, label="BC")

            # Plot Expert
            y = np.ones(max_x) * expert_performace[0]
            ma_y2 = y + expert_performace[1]
            ma_y1 = y - expert_performace[1]
            axis.fill_between(x_axis, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[-1], interpolate=True, alpha=alpha)
            axis.plot(x_axis, y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

            axis.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
            axis.set_xlabel('Frames', fontsize=lable_size)
            axis.set_ylabel('Episode Return', fontsize=lable_size)
            axis.ticklabel_format(style='sci', scilimits=(-1, 2), axis='x')

    ax[fig_row-1,fig_col-2].legend(loc='center right', ncol=1, frameon=False, fontsize="x-large", bbox_to_anchor=(2.3, 0.5))
    plt.savefig("iqm_curve_dmc_{}_smooth_{}.png".format(traj_num, window_length), bbox_inches='tight')


def reliable_plot_noisy(traj_num):
    """Plot per-frame IQM curves (rliable) for the "noisy" run variants.

    Covers dmc_finger_spin and dmc_quadruped_run in a 1x2 grid.  Every
    recorded evaluation point is aggregated across runs with the
    interquartile mean, with bootstrap intervals from
    ``rly.get_interval_estimates`` (50000 reps).  The "noise_bc" runs are
    drawn as a flat line at each run's final score.  The expert reference is
    drawn as a dashed band.  Saves ``iqm_noisy_curve_dmc_<traj_num>.png``.
    """
    from rliable import library as rly
    from rliable import metrics

    envs = ['dmc_finger_spin', 'dmc_quadruped_run']

    methods = ['PatchAIL-Weight', 'Shared-Encoder AIL', 'Independent-Encoder AIL'] #, 'BC']
    names = ['noisy_patchairl_ss_kldsimregauto1.3_buf100w_randomexpdemo_noaugsim', 'noisy28_encairl_ss', 'noisy28_noshare_encairl_ss', 'noise_bc']
    num_method = len(methods)

    fig_row, fig_col = 1, len(envs)
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 5*fig_row))
    plt.subplots_adjust(hspace=0.6, wspace=0.3)

    def iqm(scores):
        # IQM across runs at each frame; scores is (runs, frames).
        return np.array([metrics.aggregate_iqm(scores[..., frame])
                         for frame in range(scores.shape[-1])])

    for f_i in range(fig_row):
        for f_j in range(fig_col):
            # With a single row, plt.subplots returns a 1-D array of axes.
            axis = ax[f_i*fig_col+f_j]
            env = envs[f_i*fig_col+f_j]
            require_suffix = None
            expert_performace = expert_performaces[env]
            max_x = 0
            x_axis = None
            scores_dict = {}
            # x positions per method.  Kept local instead of being written
            # into the module-level frames_dict (the env -> frame-limit map).
            method_frames = {}
            frame_limit = frames_dict.get(env, np.inf)
            for i in range(num_method):
                name = names[i]+"_{}".format(traj_num)
                path_name = 'exp_local/'+env+'/'+name

                returns, x = read_multi_files(path_name, require_suffix, frame_limit=frame_limit)
                returns = np.array(returns)  # one row per run, one column per frame

                if max_x < len(x):
                    max_x = len(x)
                    x_axis = x

                scores_dict[methods[i]] = returns
                method_frames[methods[i]] = x

            iqm_scores, iqm_cis = rly.get_interval_estimates(
                scores_dict, iqm, reps=50000)

            for i, algorithm in enumerate(methods):
                metric_values = iqm_scores[algorithm]
                lower, upper = iqm_cis[algorithm]
                axis.plot(
                    method_frames[algorithm],
                    metric_values,
                    color=colors[i],
                    linewidth=line_width,
                    label=algorithm
                )
                axis.fill_between(
                    method_frames[algorithm],
                    y1=lower,
                    y2=upper,
                    color=colors[i],
                    alpha=0.2
                )

            # Plot BC: each run's final score held constant -> (runs, max_x).
            name = "noise_bc_{}".format(traj_num)
            returns, x = read_multi_files('exp_local/'+env+'/'+name, require_suffix, frame_limit=frame_limit)
            returns = np.array(returns)
            returns = np.tile(returns[:, -1:], (1, max_x))

            y1 = r_mins(returns.T)
            y2 = r_maxs(returns.T)
            y1[np.isnan(y1)] = 0.0
            y2[np.isnan(y2)] = 0.0

            axis.fill_between(x_axis, y1, y2, where=y2 >= y1, facecolor=colors[num_method], interpolate=True, alpha=alpha)
            axis.plot(x_axis, r_means(returns.T), color=colors[num_method], linewidth=line_width, label="BC")

            # Plot Expert
            y = np.ones(max_x) * expert_performace[0]
            ma_y2 = y + expert_performace[1]
            ma_y1 = y - expert_performace[1]
            axis.fill_between(x_axis, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[-1], interpolate=True, alpha=alpha)
            axis.plot(x_axis, y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

            axis.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
            axis.set_xlabel('Frames', fontsize=lable_size)
            axis.set_ylabel('Episode Return', fontsize=lable_size)
            axis.ticklabel_format(style='sci', scilimits=(-1, 2), axis='x')

    ax[0].legend(loc='lower center', ncol=3, frameon=False, fontsize="x-large", bbox_to_anchor=(1.15, -0.45))
    plt.savefig("iqm_noisy_curve_dmc_{}.png".format(traj_num), bbox_inches='tight')

def draw_noisy(traj_num=10):
    """Plot smoothed mean +/- stderr curves for the "noisy" run variants.

    Covers dmc_finger_spin and dmc_quadruped_run in a 1x2 grid.  The
    "noise_bc" runs are drawn as a flat line at each run's final score.  The
    expert reference is drawn as a dashed band.  Saves
    ``noisy_curve_dmc_<traj_num>.png``.
    """
    envs = ['dmc_finger_spin', 'dmc_quadruped_run']

    methods = ['PatchAIL-Weight', 'Shared-Encoder AIL', 'Independent-Encoder AIL'] #, 'BC']
    names = ['noisy_patchairl_ss_kldsimregauto1.3_buf100w_randomexpdemo_noaugsim', 'noisy28_encairl_ss', 'noisy28_noshare_encairl_ss', 'noise_bc']
    num_method = len(methods)
    window_length = 2

    fig_row, fig_col = 1, len(envs)
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 5*fig_row))
    plt.subplots_adjust(hspace=0.6, wspace=0.3)

    for f_i in range(fig_row):
        for f_j in range(fig_col):
            # With a single row, plt.subplots returns a 1-D array of axes.
            axis = ax[f_i*fig_col+f_j]
            env = envs[f_i*fig_col+f_j]
            require_suffix = None
            expert_performace = expert_performaces[env]
            max_x = 0
            x_axis = None
            frame_limit = frames_dict.get(env, np.inf)
            for i in range(num_method):
                name = names[i]+"_{}".format(traj_num)
                path_name = 'exp_local/'+env+'/'+name

                returns, x = read_multi_files(path_name, require_suffix, frame_limit=frame_limit)
                returns = np.array(returns).T  # one row per recorded frame, one column per run
                if max_x < len(x):
                    max_x = len(x)
                    x_axis = x

                y1 = r_mins(returns)
                y2 = r_maxs(returns)
                y1[np.isnan(y1)] = 0.0
                y2[np.isnan(y2)] = 0.0
                # MA returns lists; convert so `where` is an element-wise mask.
                ma_y1 = np.asarray(MA(y1, window_length))
                ma_y2 = np.asarray(MA(y2, window_length))

                axis.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
                axis.plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=methods[i])

            # Plot BC: each run's final score held constant -> (runs, max_x).
            name = "noise_bc_{}".format(traj_num)
            returns, x = read_multi_files('exp_local/'+env+'/'+name, require_suffix, frame_limit=frame_limit)
            returns = np.array(returns)
            returns = np.tile(returns[:, -1:], (1, max_x))

            y1 = r_mins(returns.T)
            y2 = r_maxs(returns.T)
            y1[np.isnan(y1)] = 0.0
            y2[np.isnan(y2)] = 0.0

            axis.fill_between(x_axis, y1, y2, where=y2 >= y1, facecolor=colors[num_method], interpolate=True, alpha=alpha)
            axis.plot(x_axis, r_means(returns.T), color=colors[num_method], linewidth=line_width, label="BC")

            # Plot Expert
            y = np.ones(max_x) * expert_performace[0]
            ma_y2 = y + expert_performace[1]
            ma_y1 = y - expert_performace[1]
            axis.fill_between(x_axis, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[-1], interpolate=True, alpha=alpha)
            axis.plot(x_axis, y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

            axis.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
            axis.set_xlabel('Frames', fontsize=lable_size)
            axis.set_ylabel('Episode Return', fontsize=lable_size)
            axis.ticklabel_format(style='sci', scilimits=(-1, 2), axis='x')

    ax[0].legend(loc='lower center', ncol=3, frameon=False, fontsize="x-large", bbox_to_anchor=(1.15, -0.45))
    plt.savefig("noisy_curve_dmc_{}.png".format(traj_num), bbox_inches='tight')


def draw_patch_ablation(traj_num=10):
    """Plot return curves for the discriminator patch-size ablation.

    One subplot per environment (Cheetah Run, Finger Spin). Every patch-size
    variant is drawn as its moving-averaged mean return over seeds together
    with a shaded mean +/- standard-error band; a dashed expert line with a
    band of +/- ``expert_performaces[env][1]`` is added to each subplot.
    Saves ``patch_ablation_<traj_num>.png`` and ``.pdf``.

    Args:
        traj_num: number of expert demonstrations; it is appended to every
            run directory name and to the output file names.
    """
    envs = ['dmc_cheetah_run', 'dmc_finger_spin']
    methods = ['84 x 84 Patches (Pixels)', '46 x 46 Patches', '42 x 42 Patches', '39 x 39 Patches (default)', '35 x 35 Patches', '25 x 25 Patches']
    names = ['iid_patchairl_ss', 'small_ks_2_patchairl_ss', 'small_ks_3_patchairl_ss', 'patchairl_ss', 'big_ks_5_patchairl_ss', 'big_ks_8_patchairl_ss']

    n_cols = len(envs)
    fig, ax = plt.subplots(1, n_cols, figsize=(6 * n_cols, 5))
    plt.subplots_adjust(hspace=0.6, wspace=0.3)

    smooth = 2  # local moving-average window (the module default is not used here)
    for panel, env in enumerate(envs):
        axis = ax[panel]
        expert_mean, expert_std = expert_performaces[env]
        frame_limit = frames_dict[env]

        # One curve per patch-size variant; colors are consumed in order.
        for color, label, run_name in zip(colors, methods, names):
            run_dir = 'exp_local/' + env + '/' + run_name + "_{}".format(traj_num)
            curves, x = read_multi_files(run_dir, None, frame_limit=frame_limit)
            curves = np.array(curves).T

            lower = r_mins(curves)
            upper = r_maxs(curves)
            lower[np.isnan(lower)] = 0.0
            upper[np.isnan(upper)] = 0.0
            lower = MA(lower, smooth)
            upper = MA(upper, smooth)

            axis.fill_between(x, lower, upper, where=upper >= lower, facecolor=color, interpolate=True, alpha=alpha)
            axis.plot(x, MA(r_means(curves), smooth), color=color, linewidth=line_width, label=label)

        # Expert reference drawn over the x grid of the last variant read.
        reference = np.ones_like(x) * expert_mean
        band_hi = reference + expert_std
        band_lo = reference - expert_std
        axis.fill_between(x, band_lo, band_hi, where=band_hi >= band_lo, facecolor=colors[-1], interpolate=True, alpha=alpha)
        axis.plot(x, reference, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

        axis.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
        axis.set_xlabel('Frames', fontsize=lable_size)
        axis.set_ylabel('Episode Return', fontsize=lable_size)
        axis.ticklabel_format(style='sci', scilimits=(-1,2), axis='x')

    # A single shared legend, anchored to the right of the last panel.
    ax[1].legend(loc='upper right', ncol=1, frameon=False, fontsize="x-large", bbox_to_anchor=(2.0, 1.0))

    plt.savefig("patch_ablation_{}.png".format(traj_num), bbox_inches='tight')
    plt.savefig("patch_ablation_{}.pdf".format(traj_num), bbox_inches='tight')


def draw_aggr_ablation(traj_num=10):
    """Plot return curves for the aggregation ablation.

    Compares the variants labelled 'min', 'max', 'median' and
    'mean (default)' (run directories ``patchairl_ss_min``,
    ``patchairl_ss_max``, ``patchairl_ss_median`` and ``patchairl_ss``) on
    Cheetah Run and Finger Spin. Each curve is the moving-averaged mean
    return over seeds with a shaded mean +/- standard-error band; a dashed
    expert line with a band of +/- ``expert_performaces[env][1]`` is added.
    Saves ``aggregation_ablation_<traj_num>.png`` and ``.pdf`` and then
    closes the figure.

    Args:
        traj_num: number of expert demonstrations; selects the run
            directories (``<name>_<traj_num>``) and names the output files.
    """
    envs = ['dmc_cheetah_run', 'dmc_finger_spin']
    methods = ['min', 'max', 'median', 'mean (default)']
    names = ['patchairl_ss_min', 'patchairl_ss_max', 'patchairl_ss_median', 'patchairl_ss']
    num_method = len(methods)

    fig_row, fig_col = 1, len(envs)
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 5*fig_row))
    plt.subplots_adjust(hspace=0.6, wspace=0.3)

    window_length = 2  # local smoothing window; shadows the module-level default
    require_suffix = None

    for f_i in range(fig_row):
        for f_j in range(fig_col):
            idx = f_i*fig_col + f_j
            env = envs[idx]
            expert_performace = expert_performaces[env]
            # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every version.
            # Environments without an entry in frames_dict are read unbounded.
            frame_limit = frames_dict.get(env, np.inf)

            for i in range(num_method):
                name = names[i]+"_{}".format(traj_num)

                returns, x = read_multi_files('exp_local/'+env+'/'+name, require_suffix, frame_limit=frame_limit)
                returns = np.array(returns).T

                # Band = mean -/+ standard error across seeds; NaNs are zeroed
                # so the shaded region stays drawable.
                y1 = r_mins(returns)
                y2 = r_maxs(returns)
                y1[np.isnan(y1)] = 0.0
                y2[np.isnan(y2)] = 0.0
                ma_y1 = MA(y1, window_length)
                ma_y2 = MA(y2, window_length)
                ax[idx].fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
                ax[idx].plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=methods[i])

            # Expert reference, drawn over the x grid of the last method read.
            y = np.ones_like(x) * expert_performace[0]
            ma_y2 = y + expert_performace[1]
            ma_y1 = y - expert_performace[1]
            ax[idx].fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[-1], interpolate=True, alpha=alpha)
            ax[idx].plot(x, y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

            ax[idx].set_title(env2name(env), fontsize=title_size, pad=titel_pad)
            ax[idx].set_xlabel('Frames', fontsize=lable_size)
            ax[idx].set_ylabel('Episode Return', fontsize=lable_size)
            ax[idx].ticklabel_format(style='sci', scilimits=(-1,2), axis='x')

    ax[0].legend(loc='upper right', ncol=1, frameon=False, fontsize="x-large", bbox_to_anchor=(0.7, 0.96))

    plt.savefig("aggregation_ablation_{}.png".format(traj_num), bbox_inches='tight')
    plt.savefig("aggregation_ablation_{}.pdf".format(traj_num), bbox_inches='tight')
    # Release the figure so repeated calls do not keep accumulating open figures.
    plt.close(fig)

def draw_bcaug(traj_num=1):
    """Plot return curves of the ``*_bc`` run variants on seven DMC tasks.

    The environments are laid out on a 2x4 grid; the unused eighth panel is
    hidden. For each method the per-seed returns are reduced to a
    moving-averaged mean with a shaded mean +/- standard-error band, and
    methods whose label does not contain "Patch" are drawn dashed. A dashed
    expert line with a band of +/- ``expert_performaces[env][1]`` is added to
    every panel. Saves ``bcaug_<traj_num>.png`` and ``.pdf`` and then closes
    the figure.

    Args:
        traj_num: number of expert demonstrations; selects the run
            directories (``<name>_<traj_num>``) and names the output files.
    """
    envs = ['dmc_cartpole_swingup', 'dmc_cheetah_run', 'dmc_finger_spin', 'dmc_hopper_stand', 'dmc_walker_stand', 'dmc_walker_walk', 'dmc_quadruped_run']

    methods = ['PatchAIL', 'PatchAIL-W', 'PatchAIL-B', 'ROT', 'Shared-Encoder AIL', 'Independent-Encoder AIL']
    names = ['patchairl_ss_bc', 'patchairl_ss_kldsimregauto1.3_bc', 'patchairl_ss_kldsimregaugauto0.5_bc', 'pot', 'encairl_ss_bc', 'noshare_encairl_ss_bc']
    num_method = len(methods)

    fig_row, fig_col = 2, 4
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 5*fig_row))
    plt.subplots_adjust(hspace=0.4, wspace=0.33)

    window_length = 2  # local smoothing window; shadows the module-level default
    require_suffix = None
    frame_limit = 1010000
    # The expert reference is constant in x, so its two end points draw the
    # same line and band as one sample per frame did -- without building
    # ~1M-element arrays and ~1M-vertex paths (slow, and bloats the PDF).
    expert_x = np.array([0, frame_limit - 1])

    for f_i in range(fig_row):
        for f_j in range(fig_col):
            cur_ax = ax[f_i, f_j]
            idx = f_i*fig_col + f_j
            if idx >= len(envs):
                cur_ax.set_visible(False)
                continue
            env = envs[idx]
            expert_mean, expert_std = expert_performaces[env]

            for i in range(num_method):
                # The PatchAIL-W curve is skipped for Cheetah Run.
                if env == 'dmc_cheetah_run' and names[i] == "patchairl_ss_kldsimregauto1.3_bc":
                    continue
                name = names[i]+"_{}".format(traj_num)

                returns, x = read_multi_files('exp_local/'+env+'/'+name, require_suffix, frame_limit=frame_limit)
                returns = np.array(returns).T

                # Band = mean -/+ standard error across seeds; NaNs are zeroed.
                y1 = r_mins(returns)
                y2 = r_maxs(returns)
                y1[np.isnan(y1)] = 0.0
                y2[np.isnan(y2)] = 0.0
                ma_y1 = MA(y1, window_length)
                ma_y2 = MA(y2, window_length)
                cur_ax.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
                # Baselines (labels without "Patch") are drawn dashed.
                style = {} if "Patch" in methods[i] else {'linestyle': 'dashed'}
                cur_ax.plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=methods[i], **style)

            expert_y = np.ones(expert_x.shape) * expert_mean
            band_hi = expert_y + expert_std
            band_lo = expert_y - expert_std
            cur_ax.fill_between(expert_x, band_lo, band_hi, where=band_hi >= band_lo, facecolor=colors[-1], interpolate=True, alpha=alpha)
            cur_ax.plot(expert_x, expert_y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

            cur_ax.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
            cur_ax.set_xlabel('Frames', fontsize=lable_size)
            cur_ax.set_ylabel('Episode Return', fontsize=lable_size)
            cur_ax.ticklabel_format(style='sci', scilimits=(-1,2), axis='x')

    ax[fig_row-1,fig_col-2].legend(loc='center right', ncol=1, frameon=False, fontsize="x-large", bbox_to_anchor=(2.3, 0.5))

    plt.savefig("bcaug_{}.png".format(traj_num), bbox_inches='tight')
    plt.savefig("bcaug_{}.pdf".format(traj_num), bbox_inches='tight')
    # Release the figure so repeated calls do not keep accumulating open figures.
    plt.close(fig)

def draw_observation_choice(traj_num=10):
    """Plot D(s,s) vs D(s) discriminator-input variants, one panel per task.

    Each method's per-seed returns are reduced to a moving-averaged mean with
    a shaded mean +/- standard-error band; methods whose label does not
    contain "Patch" are drawn dashed. A dashed expert line with a band of
    +/- ``expert_performaces[env][1]`` is added to every panel. Saves
    ``observation_ablation_<traj_num>.png`` and ``.pdf`` and then closes the
    figure.

    Args:
        traj_num: number of expert demonstrations; selects the run
            directories (``<name>_<traj_num>``) and names the output files.
    """
    envs = ['dmc_cartpole_swingup', 'dmc_cheetah_run', 'dmc_finger_spin', 'dmc_hopper_stand'] # , 'dmc_walker_run', 'dmc_walker_walk', 'dmc_quadruped_run']

    methods = ['PatchAIL w.o Reg-D(s,s)', 'PatchAIL w.o Reg-D(s)', 'Shared-Encoder AIL-D(s,s)', 'Shared-Encoder AIL-D(s)']
    names = ['patchairl_ss', 'patchairl', 'encairl_ss', 'encairl']
    num_method = len(methods)

    # fig_row, fig_col = 2, 4
    fig_row, fig_col = 1, 4
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 4*fig_row))
    plt.subplots_adjust(hspace=0.45, wspace=0.35)
    # plt.subplots returns a 1-D array for a 1xN grid and a 2-D array for RxC;
    # flattening lets one linear index address a panel in either layout
    # (indexing the 1-D array as ax[f_i][f_j] would fail).
    axes = np.asarray(ax).ravel()

    window_length = 2  # local smoothing window; shadows the module-level default
    require_suffix = None

    for f_i in range(fig_row):
        for f_j in range(fig_col):
            idx = f_i*fig_col + f_j
            cur_ax = axes[idx]
            if idx >= len(envs):
                cur_ax.set_visible(False)
                continue
            env = envs[idx]
            expert_performace = expert_performaces[env]
            # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every version.
            frame_limit = frames_dict.get(env, np.inf)

            for i in range(num_method):
                name = names[i]+"_{}".format(traj_num)

                returns, x = read_multi_files('exp_local/'+env+'/'+name, require_suffix, frame_limit=frame_limit)
                returns = np.array(returns).T

                # Band = mean -/+ standard error across seeds; NaNs are zeroed.
                y1 = r_mins(returns)
                y2 = r_maxs(returns)
                y1[np.isnan(y1)] = 0.0
                y2[np.isnan(y2)] = 0.0
                ma_y1 = MA(y1, window_length)
                ma_y2 = MA(y2, window_length)
                cur_ax.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
                # Baselines (labels without "Patch") are drawn dashed.
                style = {} if "Patch" in methods[i] else {'linestyle': 'dashed'}
                cur_ax.plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=methods[i], **style)

            # Expert reference, drawn over the x grid of the last method read.
            y = np.ones_like(x) * expert_performace[0]
            ma_y2 = y + expert_performace[1]
            ma_y1 = y - expert_performace[1]
            cur_ax.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[-1], interpolate=True, alpha=alpha)
            cur_ax.plot(x, y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

            cur_ax.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
            cur_ax.set_xlabel('Frames', fontsize=lable_size)
            cur_ax.set_ylabel('Episode Return', fontsize=lable_size)
            cur_ax.ticklabel_format(style='sci', scilimits=(-1,2), axis='x')

    axes[1].legend(loc='lower center', ncol=5, frameon=False, fontsize="x-large", bbox_to_anchor=(1.1, -0.45))

    plt.savefig("observation_ablation_{}.png".format(traj_num), bbox_inches='tight')
    plt.savefig("observation_ablation_{}.pdf".format(traj_num), bbox_inches='tight')
    # Release the figure so repeated calls do not keep accumulating open figures.
    plt.close(fig)

def draw_observation_choice_in_one(traj_num=10):
    """Plot D(s,s) vs D(s) discriminator inputs aggregated over all tasks.

    For every environment and method the per-seed returns are averaged into a
    single curve. The curves of all environments are then treated as samples
    at each evaluation point, and one axes shows, per method, their
    moving-averaged mean with a shaded mean +/- standard-error band.

    Saves ``observation_ablation_in_one_<traj_num>.png`` and ``.pdf``. A
    distinct name is used so that this figure no longer overwrites the output
    of ``draw_observation_choice``. The figure is closed afterwards.

    NOTE(review): raw returns are averaged across tasks without per-task
    normalisation, so high-return tasks dominate; consider dividing by
    ``expert_performaces[env][0]`` if a normalised score is intended.

    Args:
        traj_num: number of expert demonstrations; selects the run
            directories (``<name>_<traj_num>``) and names the output files.
    """
    envs = ['dmc_cartpole_swingup', 'dmc_cheetah_run', 'dmc_finger_spin', 'dmc_hopper_stand', 'dmc_walker_run', 'dmc_walker_walk', 'dmc_quadruped_run']

    methods = ['D(s,s)', 'D(s)']
    names = ['encairl_ss', 'encairl']

    fig_row, fig_col = 1, 1
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 5*fig_row))
    plt.subplots_adjust(hspace=0.6, wspace=0.3)

    window_length = 2  # local smoothing window; shadows the module-level default
    require_suffix = None
    # method -> list of (x, seed-averaged curve), one entry per environment
    curves_dict = {m: [] for m in methods}

    for env in envs:
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every version.
        frame_limit = 1600000 if env in frames_dict else np.inf
        for method, base_name in zip(methods, names):
            name = base_name + "_{}".format(traj_num)
            returns, x = read_multi_files('exp_local/'+env+'/'+name, require_suffix, frame_limit=frame_limit)
            returns = np.array(returns).T
            curves_dict[method].append((np.asarray(x), r_means(returns)))

    for i, method in enumerate(methods):
        entries = curves_dict[method]
        # Environments may log different numbers of evaluation points (their
        # frame limits differ), so truncate to the shortest before stacking.
        n = min(len(curve) for _, curve in entries)
        x = next(xs for xs, curve in entries if len(curve) == n)[:n]
        # Shape (T, n_envs): r_* reduce over axis 1, i.e. across environments
        # at each evaluation point. (The original reduced the stacked
        # (n_envs, T) array over time first, leaving 1-D data that r_mins
        # could not reduce over axis 1.)
        returns = np.stack([curve[:n] for _, curve in entries]).T

        y1 = r_mins(returns)
        y2 = r_maxs(returns)
        y1[np.isnan(y1)] = 0.0
        y2[np.isnan(y2)] = 0.0
        ma_y1 = MA(y1, window_length)
        ma_y2 = MA(y2, window_length)

        ax.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
        ax.plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=method)

    # The panel aggregates every environment, so no single env name applies.
    ax.set_title('Average over {} Tasks'.format(len(envs)), fontsize=title_size, pad=titel_pad)
    ax.set_xlabel('Frames', fontsize=lable_size)
    ax.set_ylabel('Episode Return', fontsize=lable_size)
    ax.ticklabel_format(style='sci', scilimits=(-1,2), axis='x')
    ax.legend(loc='center right', ncol=3, frameon=False, fontsize="x-large", bbox_to_anchor=(2.3, 0.5))

    plt.savefig("observation_ablation_in_one_{}.png".format(traj_num), bbox_inches='tight')
    plt.savefig("observation_ablation_in_one_{}.pdf".format(traj_num), bbox_inches='tight')
    # Release the figure so repeated calls do not keep accumulating open figures.
    plt.close(fig)

def draw_aug_ablation(traj_num=10):
    """Plot the data-augmentation ablation on Finger Spin.

    Five panels, one per algorithm family (titles below). Inside each panel
    every augmentation variant is drawn as its moving-averaged mean return
    over seeds with a shaded mean +/- standard-error band, next to a dashed
    expert line with a band of +/- ``expert_performaces[env][1]``.
    Writes ``augment_ablation_<traj_num>.png`` and ``.pdf``.

    Args:
        traj_num: number of expert demonstrations, appended to every run
            directory name and to the output file names.
    """
    env = 'dmc_finger_spin'

    methods = ['Random-Shift (default)', 'Random-Cutout', 'Random-Crop-Resize', 'RandAugment']
    # Row k lists the run names for panel k; column i pairs with methods[i].
    names = [
        ['encairl_ss', 'encairl_ss_disc_randomcut', 'encairl_ss_disc_randomcrop', 'encairl_ss_disc_random_aug'],
        ['noshare_encairl_ss', 'noshare_encairl_ss_disc_randomcut', 'noshare_encairl_ss_disc_randomcrop', 'noshare_encairl_ss_disc_random_aug'],
        ['patchairl_ss', 'patchairl_ss_random_cutout', 'patchairl_ss_randomcrop', 'patchairl_ss_random_aug'],
        ['patchairl_ss_kldsimregauto1.3_buf100w_randomexpdemo_noaugsim', 'patchairl_ss_kldsimregauto1.3_random_cutout', 'patchairl_ss_kldsimregauto1.3_randomcrop', 'patchairl_ss_kldsimregauto1.3_random_aug'],
        ['patchairl_ss_kldsimregaugauto0.5_buf100w_randomexpdemo_noaugsim', 'patchairl_ss_kldsimregaugauto0.5_random_cutout', 'patchairl_ss_kldsimregaugauto0.5_randomcrop', 'patchairl_ss_kldsimregaugauto0.5_random_aug'],
    ]
    titles = ['Share-Encoder AIL', 'Independent-Encoder AIL', 'Patch AIL w.o. Reg', 'PatchAIL-Weight', 'PatchAIL-Bonus']

    n_panels = 5
    fig, ax = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5))
    plt.subplots_adjust(hspace=0.6, wspace=0.3)

    smooth = 2  # local moving-average window (the module default is not used here)
    expert_mean, expert_std = expert_performaces[env]
    frame_limit = frames_dict[env]

    for panel, variants in enumerate(names):
        axis = ax[panel]

        for color, label, variant in zip(colors, methods, variants):
            run_dir = 'exp_local/' + env + '/' + variant + "_{}".format(traj_num)
            curves, x = read_multi_files(run_dir, None, frame_limit=frame_limit)
            curves = np.array(curves).T

            lower = r_mins(curves)
            upper = r_maxs(curves)
            lower[np.isnan(lower)] = 0.0
            upper[np.isnan(upper)] = 0.0
            lower = MA(lower, smooth)
            upper = MA(upper, smooth)

            axis.fill_between(x, lower, upper, where=upper >= lower, facecolor=color, interpolate=True, alpha=alpha)
            axis.plot(x, MA(r_means(curves), smooth), color=color, linewidth=line_width, label=label)

        axis.set_xlabel('Frames', fontsize=lable_size)
        axis.set_ylabel('Episode Return', fontsize=lable_size)
        axis.ticklabel_format(style='sci', scilimits=(-1,2), axis='x')

        # Expert reference drawn over the x grid of the last variant read.
        reference = np.ones_like(x) * expert_mean
        band_hi = reference + expert_std
        band_lo = reference - expert_std
        axis.fill_between(x, band_lo, band_hi, where=band_hi >= band_lo, facecolor=colors[-1], interpolate=True, alpha=alpha)
        axis.plot(x, reference, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')

    for axis, title in zip(ax, titles):
        axis.set_title(title, fontsize=title_size, pad=titel_pad)
    # A single shared legend to the right of the last panel.
    ax[4].legend(loc='center right', ncol=1, frameon=False, fontsize="x-large", bbox_to_anchor=(2.1, 0.7))

    plt.savefig("augment_ablation_{}.png".format(traj_num), bbox_inches='tight')
    plt.savefig("augment_ablation_{}.pdf".format(traj_num), bbox_inches='tight')

def draw_similarity(traj_num=10):
    """Plot returns (top row) and batch similarity (bottom row) per task.

    Compares PatchAIL-W / PatchAIL-B runs labelled "eq(7)" (solid) against
    those labelled "eq(6)" (dashed) on Cartpole Swingup and Finger Spin.
    Top row: moving-averaged mean episode return with a shaded
    mean +/- standard-error band, plus a dashed expert line with a band of
    +/- ``expert_performaces[env][1]``. Bottom row: the same statistics for
    the ``similarity`` column read with ``readfrom='train'``.
    Saves ``similarity_<traj_num>.png`` and ``.pdf`` and then closes the
    figure.

    Args:
        traj_num: number of expert demonstrations; selects the run
            directories (``<name>_<traj_num>``) and names the output files.
    """
    envs = ['dmc_cartpole_swingup', 'dmc_finger_spin']
    methods = ['PatchAIL-W eq(7)', 'PatchAIL-B eq(7)', 'PatchAIL-W eq(6)', 'PatchAIL-B eq(6)']
    names = [ 'patchairl_ss_kldsimregauto1.3_buf100w_randomexpdemo_noaugsim', 'patchairl_ss_kldsimregaugauto0.5_buf100w_randomexpdemo_noaugsim',
        'patchairl_ss_kldsimregauto1.3_buf100w_minsim', 'patchairl_ss_kldsimregaugauto0.5_buf100w_minsim_rewscale0.5']

    fig_row, fig_col = 2, len(envs)
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(6*fig_col, 4*fig_row))
    plt.subplots_adjust(hspace=0.6, wspace=0.3)

    window_length = 2  # local smoothing window; shadows the module-level default
    require_suffix = None
    frame_limit = 1600000
    # The expert reference is constant in x, so its two end points draw the
    # same line and band as one sample per frame did -- without building
    # ~1.6M-element arrays and ~1.6M-vertex paths (slow, and bloats the PDF).
    expert_x = np.array([0, frame_limit - 1])

    def plot_method_curves(axis, env, **read_kwargs):
        """Draw one smoothed mean +/- stderr curve per method on `axis`."""
        for i, method in enumerate(methods):
            name = names[i]+"_{}".format(traj_num)
            returns, x = read_multi_files('exp_local/'+env+'/'+name, require_suffix, frame_limit=frame_limit, **read_kwargs)
            returns = np.array(returns).T

            y1 = r_mins(returns)
            y2 = r_maxs(returns)
            y1[np.isnan(y1)] = 0.0
            y2[np.isnan(y2)] = 0.0
            ma_y1 = MA(y1, window_length)
            ma_y2 = MA(y2, window_length)
            axis.fill_between(x, ma_y1, ma_y2, where=ma_y2 >= ma_y1, facecolor=colors[i], interpolate=True, alpha=alpha)
            # "eq(6)" variants are drawn dashed, "eq(7)" variants solid.
            style = {} if "eq(7)" in method else {'linestyle': 'dashed'}
            axis.plot(x, MA(r_means(returns), window_length), color=colors[i], linewidth=line_width, label=method, **style)

    for col, env in enumerate(envs):
        expert_mean, expert_std = expert_performaces[env]

        # Top row: episode returns plus the expert reference.
        top = ax[0, col]
        plot_method_curves(top, env)
        expert_y = np.ones(expert_x.shape) * expert_mean
        band_hi = expert_y + expert_std
        band_lo = expert_y - expert_std
        top.fill_between(expert_x, band_lo, band_hi, where=band_hi >= band_lo, facecolor=colors[-1], interpolate=True, alpha=alpha)
        top.plot(expert_x, expert_y, color=colors[-1], linewidth=line_width, label="Expert", linestyle='dashed')
        top.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
        top.set_xlabel('Frames', fontsize=lable_size)
        top.set_ylabel('Episode Return', fontsize=lable_size)
        top.ticklabel_format(style='sci', scilimits=(-1,2), axis='x')

        # Bottom row: the logged similarity statistic from the training logs.
        bottom = ax[1, col]
        plot_method_curves(bottom, env, column='similarity', readfrom='train')
        bottom.set_title(env2name(env), fontsize=title_size, pad=titel_pad)
        bottom.set_xlabel('Frames', fontsize=lable_size)
        bottom.set_ylabel('Averaged Batch Similarity', fontsize=lable_size-6)
        bottom.ticklabel_format(style='sci', scilimits=(-1,2), axis='x')

    # Shared legend built from the top-right panel (returns + expert).
    ax[0, -1].legend(loc='upper right', ncol=1, frameon=False, fontsize="x-large", bbox_to_anchor=(1.8, 1.0))

    plt.savefig("similarity_{}.png".format(traj_num), bbox_inches='tight')
    plt.savefig("similarity_{}.pdf".format(traj_num), bbox_inches='tight')
    # Release the figure; __main__ calls this function more than once.
    plt.close(fig)


if __name__ == '__main__':
    # Figures regenerated by this run, in this order.
    draw_bcaug(traj_num=1)
    for n_traj in (10, 5):
        draw_similarity(traj_num=n_traj)

    # Other figure entry points; un-comment the ones to regenerate:
    # draw_motivation(traj_num=10)
    # draw_motivation_single(traj_num=10)
    # draw_main(traj_num=10)
    # draw_main(traj_num=5)
    # draw_main(traj_num=1)
    # draw_patch_ablation(traj_num=10)
    # draw_aggr_ablation(traj_num=10)
    # draw_aug_ablation(traj_num=10)
    # draw_observation_choice(traj_num=10)
    # draw_observation_choice(traj_num=5)
    # draw_observation_choice(traj_num=1)
    # draw_observation_choice_in_one(traj_num=10)
    # draw_observation_choice_in_one(traj_num=5)
    # draw_observation_choice_in_one(traj_num=1)
    # draw_noisy(traj_num=10)
    # reliable_plot_noisy(traj_num=10)
    # reliable_plot(traj_num=10)